import os
import gc
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data.dataset import TensorDataset

def _to_time_deltas(t):
    """Replace each column of ``t`` by its difference to the previous column.

    Works in place on a 2-D tensor and returns it. Column 0 has no
    predecessor and is set to the constant 1. A tensor with zero columns is
    returned untouched.
    """
    if t.size(1) == 0:
        return t
    # The right-hand side is materialised as a new tensor before the
    # assignment, so the overlapping slices do not alias each other.
    t[:, 1:] = t[:, 1:] - t[:, :-1]
    t[:, 0] = 1
    return t


def load_data(path, batch_size = 128):
    """Load the train/test ``.npy`` splits found in ``path`` as DataLoaders.

    The directory must contain ``{train,test}_t.npy``, ``{train,test}_x.npy``
    and ``{train,test}_y.npy``. Every array is converted to a float tensor.
    The ``*_t`` arrays are converted from per-column values to differences
    between consecutive columns (first column fixed to 1); presumably they
    hold observation times, so the result is the step size fed to the model.

    Each split is converted using its *own* number of columns, so the train
    and test sets may have different sequence lengths.

    Args:
        path: Directory holding the six ``.npy`` files.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``; each yields ``(t, x, y)`` batches.
        Only the training loader shuffles.
    """
    def read(name):
        return torch.FloatTensor(np.load(os.path.join(path, name)))

    t_train = _to_time_deltas(read('train_t.npy'))
    feat_train = read('train_x.npy')
    edges_train = read('train_y.npy')

    t_test = _to_time_deltas(read('test_t.npy'))
    feat_test = read('test_x.npy')
    edges_test = read('test_y.npy')

    train_data = TensorDataset(t_train, feat_train, edges_train)
    test_data = TensorDataset(t_test, feat_test, edges_test)

    train_data_loader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_data_loader = Data.DataLoader(test_data, batch_size=batch_size)

    return train_data_loader, test_data_loader

class CTRNN(nn.Module):
    """Recurrent cell whose hidden state is advanced with explicit Euler steps.

    For every input step the hidden state is updated ``num_unfolds`` times
    with step size ``t[:, i] / num_unfolds`` using the derivative returned by
    :meth:`dfdt`. The final hidden state is mapped through a linear decoder.

    Args:
        device: Stored as ``self.device`` for backward compatibility. The
            hidden state is allocated on the device of the inputs, so the
            module keeps working after ``.to(...)`` / ``.double()``.
        input_size: Size of the last dimension of ``x``.
        hidden_size: Number of hidden units.
        output_size: Size of the decoder output.
        num_unfolds: Number of Euler sub-steps per input step.
        tau: Coefficient of the ``- hidden_state * tau`` decay term; the term
            is dropped when ``tau <= 0``.
    """

    def __init__(self, device, input_size: int, hidden_size: int, output_size: int, num_unfolds: int = 3, tau: int = 1):
        super().__init__()
        self.device = device
        self.units = hidden_size
        self.state_size = hidden_size
        self.num_unfolds = num_unfolds
        self.tau = tau

        self.kernel = nn.Linear(input_size, hidden_size)
        self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False)
        # Per-unit gain applied to the tanh activation; starts at 1.
        self.scale = nn.Parameter(torch.ones(hidden_size))

        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, t, x):
        """Integrate the hidden state over the sequence and decode it.

        Args:
            t: Tensor of shape ``(batch_size, n_take)``; column ``i`` is the
                total step size used for input step ``i``.
            x: Tensor of shape ``(batch_size, n_take, input_size)``.

        Returns:
            Tensor of shape ``(batch_size, output_size)`` decoded from the
            hidden state after the last step.
        """
        # Follow the inputs' device and dtype instead of the device given at
        # construction time and the default dtype, which may both be stale.
        hidden_state = torch.zeros(
            (t.size(0), self.units), device=x.device, dtype=x.dtype
        )
        for i in range(t.size(1)):
            delta_t = t[:, i] / self.num_unfolds
            inputs = x[:, i, :]
            for _ in range(self.num_unfolds):
                hidden_state = self.euler(inputs, hidden_state, delta_t)
        return self.decoder(hidden_state)

    def dfdt(self, inputs, hidden_state):
        """Return the time derivative of the hidden state.

        Computes ``scale * tanh(kernel(inputs) + recurrent_kernel(h))`` and,
        when ``tau > 0``, subtracts ``h * tau``.
        """
        dh_in = self.scale * (self.kernel(inputs) + self.recurrent_kernel(hidden_state)).tanh()
        if self.tau > 0:
            return dh_in - hidden_state * self.tau
        return dh_in

    def euler(self, inputs, hidden_state, delta_t):
        """Advance ``hidden_state`` by one explicit Euler step of size ``delta_t``.

        ``delta_t`` holds one step size per batch row. The derivative is
        transposed to ``(hidden, batch)`` so the per-row step broadcasts along
        the last dimension, then transposed back.
        """
        derivative = self.dfdt(inputs, hidden_state)
        return hidden_state + (delta_t * derivative.t()).t()

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Deterministic cuBLAS matmuls need this workspace setting; a value the
    # user already exported is left alone.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    # This has to be *called*: assigning True to the attribute only replaces
    # the function with a bool and switches nothing on. warn_only keeps ops
    # without a deterministic implementation from aborting the run.
    torch.use_deterministic_algorithms(True, warn_only=True)


def _run_epoch(model, loader, criterion, device, optimizer=None, scheduler=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    The pass is a training pass when ``optimizer`` is given (the scheduler,
    if any, is stepped after every batch); otherwise the model is put in eval
    mode and gradients are disabled. Both metrics are averaged over every
    element of every target tensor; a prediction counts as positive when its
    logit is greater than 0.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    num_correct = 0
    num_elements = 0
    with torch.set_grad_enabled(training):
        for t, x, y in tqdm(loader):
            t, x, y = t.to(device), x.to(device), y.to(device)
            output = model(t, x)
            loss = criterion(output, y)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            num_correct += int(torch.sum((output > 0).int() == y))
            # The criterion averages over all elements of the batch, so it is
            # weighted by the element count to get an exact epoch mean.
            loss_sum += loss.item() * y.numel()
            num_elements += y.numel()

    if num_elements == 0:
        return float('nan'), float('nan')
    return loss_sum / num_elements, num_correct / num_elements


def _save_history(frame, path, fallback_path, label):
    """Write ``frame`` to ``path``; on an OS-level failure use ``fallback_path``."""
    try:
        frame.to_csv(path)
    except OSError:
        # Typically the target is locked by another program; keep the run
        # alive by writing next to it instead.
        print(f'Fail to save the file {label}.csv')
        frame.to_csv(fallback_path)


def main(version, hidden_size, *, data_dir='irregular_spring', num_epochs=50, batch_size=128, seed=42):
    """Train and evaluate a CTRNN, saving the model and metric history each epoch.

    After every epoch the whole module is pickled to ``{version}.pkl`` and the
    metric histories are written to ``{version}_Train.csv`` and
    ``{version}_Test.csv`` (``*_1.csv`` if the primary file cannot be written).

    The logged loss is the mean binary cross-entropy per target element.
    Earlier versions of this function divided a batch-size-weighted sum by the
    element count, which under-reported the loss by the number of outputs.

    Args:
        version: Prefix for every file written by the run.
        hidden_size: Number of hidden units of the CTRNN.
        data_dir: Directory handed to ``load_data``.
        num_epochs: Number of training epochs.
        batch_size: Batch size for both loaders.
        seed: Seed for all random number generators.
    """
    _seed_everything(seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader, test_loader = load_data(data_dir, batch_size)

    print(device)

    # Layer sizes follow the data: last dimension of the features / targets.
    _, features, targets = train_loader.dataset.tensors
    model = CTRNN(
        device,
        input_size=features.shape[-1],
        hidden_size=hidden_size,
        output_size=targets.shape[-1],
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # The scheduler is stepped once per batch, hence T_max counts batches.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = num_epochs * len(train_loader), eta_min = 0.00005, last_epoch = -1)

    criterion = nn.BCEWithLogitsLoss()

    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(num_epochs):
        print(version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))

        loss, acc = _run_epoch(model, train_loader, criterion, device, optimizer, scheduler)
        train_loss.append(loss)
        train_acc.append(acc)
        print(' ', train_loss[-1], train_acc[-1])

        loss, acc = _run_epoch(model, test_loader, criterion, device)
        test_loss.append(loss)
        test_acc.append(acc)
        print(' ', test_loss[-1], test_acc[-1])

        # NOTE: this pickles the whole module, so loading it requires the
        # CTRNN class to be importable. Kept as-is for existing consumers.
        torch.save(model, f'{version}.pkl')

        _save_history(
            pd.DataFrame({'Train Loss': train_loss, 'Train Acc': train_acc}),
            f'{version}_Train.csv',
            f'{version}_Train_1.csv',
            'Train',
        )
        _save_history(
            pd.DataFrame({'Test Loss': test_loss, 'Test Acc': test_acc}),
            f'{version}_Test.csv',
            f'{version}_Test_1.csv',
            'Test',
        )

    gc.collect()

if __name__ == '__main__':
    # Script entry point: train one model per hidden width, one after the
    # other; each run writes its files under its own version prefix.
    for width in (128, 256, 512):
        main(f'CTRNN_IrrSpring_{width}', width)